package org.zjvis.datascience.service.socket;

import com.alibaba.fastjson.JSONObject;
import com.corundumstudio.socketio.SocketIOClient;
import com.corundumstudio.socketio.SocketIOServer;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;
import org.zjvis.datascience.common.constant.NoticeConstant;
import org.zjvis.datascience.common.enums.TaskInstanceStatus;
import org.zjvis.datascience.common.exception.BaseErrorCode;
import org.zjvis.datascience.common.model.ApiResultCode;
import org.zjvis.datascience.common.util.JwtUtil;
import org.zjvis.datascience.service.PipelineService;

import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/**
 * @description WebSocket 服务层Service
 * @date 2021-11-15
 */
@Service
public class SocketIOService {
    private final static Logger logger = LoggerFactory.getLogger(SocketIOService.class);

    // 用来存已连接的客户端
    private static Map<String, Set<SocketIOClient>> clientsMap = new ConcurrentHashMap<>();

    @Autowired
    private SocketIOServer socketIOServer;

    @Lazy
    @Autowired
    private PipelineService pipelineService;

    /**
     * Spring IoC容器创建之后，在加载SocketIOServiceImpl Bean之后启动
     *
     * @throws Exception
     */
    @PostConstruct
    private void autoStartup() throws Exception {
        start();
    }

    @PreDestroy
    private void autoStop() throws Exception {
        stop();
    }

    public void start() {
        // 监听客户端连接
        socketIOServer.addConnectListener(client -> {
            logger.info("客户端 " + client.getRemoteAddress().toString() + " 创建连接，session id: " + client.getSessionId());

            // 给客户端发送一条信息 发送ClientReceive事件 需要客户端绑定此事件即可接收到消息
            JSONObject jsonObject = new JSONObject();
            jsonObject.put("name", "connect");
            jsonObject.put("message", "create connect success");
        });

        // 监听客户端断开连接
        socketIOServer.addDisconnectListener(client -> {
            logger.info("客户端 " + client.getRemoteAddress().toString() + " 连接中断,session id: " + client.getSessionId());
            client.disconnect();

            cleanClosedChannels();
        });

        // 处理自定义的事件，与连接监听类似
        // 此示例中测试的json收发 所以接收参数为JSONObject 如果是字符类型可以用String.class或者Object.class
        socketIOServer.addEventListener("register", JSONObject.class, (client, data, ackSender) -> {
            JSONObject jsonObject = data;
            JSONObject jo = new JSONObject();
            if (!jsonObject.containsKey("token")) {
                jo.put("message", BaseErrorCode.SOCKET_TOKEN_MISSING.getMsg());
                jo.put("code", BaseErrorCode.SOCKET_TOKEN_MISSING.getCode());
                client.sendEvent("register", jo);
                return;
            }
            String token = jsonObject.getString("token");
            boolean isTokenExpired = false;
            try {
                isTokenExpired = JwtUtil.isTokenExpired(token);
            } catch (Exception e) {
                logger.error(e.getMessage());
                jo.put("message", BaseErrorCode.SOCKET_TOKEN_USELESS.getMsg());
                jo.put("code", BaseErrorCode.SOCKET_TOKEN_USELESS.getCode());
                client.sendEvent("register", jo);
                return;
            }

            if (isTokenExpired) {
                //过期
                jo.put("message", BaseErrorCode.SOCKET_TOKEN_EXPIRE.getMsg());
                jo.put("code", BaseErrorCode.SOCKET_TOKEN_EXPIRE.getCode());
                client.sendEvent("register", jo);
                return;
            }

            Long userId = JwtUtil.getUserId(token);
            if (userId == null) {
                jo.put("message", BaseErrorCode.SOCKET_TOKEN_USELESS.getMsg());
                jo.put("code", BaseErrorCode.SOCKET_TOKEN_USELESS.getCode());
                client.sendEvent("register", jo);
                return;
            }

            logger.info("用户： " + userId + " 注册");

            addClient(userId.toString(), client);

            jo.put("message", ApiResultCode.SUCCESS.getMessage());
            jo.put("code", ApiResultCode.SUCCESS.getCode());
            client.sendEvent("register", jo);
        });

        socketIOServer.addEventListener(NoticeConstant.PIPELINE_QUERY_STATUS, JSONObject.class, (client, data, ackSender) -> {
            //为可视化构建提供的queryStatus WS
            logger.info(" queryStatus event 开始：");
            if (data.isEmpty() || !data.containsKey("pipelineId")) {
                logger.error("" + ApiResultCode.PARAM_ERROR.getMessage());
                return;
            }
            addClient(data.getString("pipelineId"), client);
            if (data.containsKey("sessionId")) {
                JSONObject statusObj = pipelineService.queryStatus(data.getString("sessionId"))
                        .getJSONArray("sessionLogs").getJSONObject(0);
                logger.warn("enter queryStatus -> {}", statusObj);
                while (statusObj.getString("status").equals(TaskInstanceStatus.CREATE.toString()) ||
                        statusObj.getString("status").equals(TaskInstanceStatus.RUNNING.toString())) {
                    statusObj = pipelineService.queryStatus(data.getString("sessionId"))
                            .getJSONArray("sessionLogs").getJSONObject(0);
                    Thread.sleep(1000);
                }
                logger.warn("send queryStatus event end..");
                client.sendEvent(NoticeConstant.PIPELINE_QUERY_STATUS, statusObj);
            }
        });

        socketIOServer.addEventListener(NoticeConstant.SOCKET_EVENT_PROJECT_USER_ROLE, JSONObject.class, (client, data, ackSender) -> {
            logger.info("event开始：" + NoticeConstant.SOCKET_EVENT_PROJECT_USER_ROLE);
            JSONObject jsonObject = data;
            if (data == null || data.isEmpty()) {
                logger.error("data is empty");
            } else {
                if (data.containsKey("event")) {
                    String event = jsonObject.getString("event");
                    if (StringUtils.isNotBlank(event)) {
                        client.sendEvent(NoticeConstant.SOCKET_EVENT_PROJECT_USER_ROLE, jsonObject);
                    }
                } else {
                    logger.error("no event");
                }
            }
        });

        socketIOServer.start();
        logger.info("socket.io初始化服务完成");
    }

    public void stop() {
        if (socketIOServer != null) {
            socketIOServer.stop();
            socketIOServer = null;
        }
        logger.info("socket.io服务已关闭");
    }

    /**
     * 添加client
     *
     * @param key
     * @param client
     */
    public void addClient(String key, SocketIOClient client) {
        Set<SocketIOClient> scs = clientsMap.get(key);
        if (scs == null) {
            logger.info("创建 " + key + " 列表,添加session：" + client.getSessionId());
            scs = new HashSet<>();
            scs.add(client);
            clientsMap.put(key, scs);
        } else if (!scs.contains(client)) {
            logger.info("为 " + key + " 列表添加新数据，session：" + client.getSessionId());
            scs.add(client);
        }
    }

    /**
     * 清除缓存中关闭的连接
     */
    private void cleanClosedChannels() {
        logger.info("清除缓存中关闭的连接");
        Set<String> keys = clientsMap.keySet();
        for (String key : keys) {
            Set<SocketIOClient> scs = clientsMap.get(key);
            if (scs == null) {
                continue;
            } else {
                Iterator<SocketIOClient> iterator = scs.iterator();
                while (iterator.hasNext()) {
                    SocketIOClient sc = iterator.next();
                    if (!sc.isChannelOpen()) {
                        iterator.remove();
                    }
                }
            }
        }
    }

    /**
     * 通知指定事件内容给指定用户
     *
     * @param userId 用户id
     * @param event  事件名
     * @param data   消息内容
     */
    public void sendToUser(String userId, String event, JSONObject data) {
        Set<SocketIOClient> socketIOClients = clientsMap.get(userId);
        if (socketIOClients != null) {
            for (SocketIOClient socketIOClient : socketIOClients) {
                if (socketIOClient.isChannelOpen()) {
                    logger.info("向" + userId + "发送" + event + "");
                    socketIOClient.sendEvent(event, data);
                } else {
                    logger.info("channel closed，不推送");
                }
            }
        } else {
            logger.info("[sendToUser] 获取不到client，不推送");
        }
    }

    /**
     * @param event
     * @param data
     */
    public void sendToDashboard(String event, JSONObject data) throws InterruptedException {
        Set<SocketIOClient> socketIOClients = clientsMap.get(data.getString("pipelineId"));
        if (socketIOClients != null) {
            for (SocketIOClient socketIOClient : socketIOClients) {
                if (socketIOClient.isChannelOpen()) {
                    logger.info("向" + data.getString("sessionId") + "发送" + event + " 内容" + data);
                    socketIOClient.sendEvent(event, data);
                } else {
                    logger.info("channel closed，不推送");
                }
            }
        } else {
            logger.info("[sendToDashboard] 获取不到client，不推送");
        }
    }
}
